import utils.config as config
import numpy as np
from pypower.idx_brch import BR_STATUS
from pypower.makeYbus import makeYbus
import math

def gen_bound(Pred_Pg, Pred_Qg):
    """Build per-sample generator P/Q limits for the test scenarios.

    The base limits ``config.PG_Upbound`` / ``PG_Lowbound`` / ``QG_Upbound`` /
    ``QG_Lowbound`` are tiled to one row per sample. A strictly negative
    ``config.branch_test_id[i]`` encodes generator ``-branch_test_id[i]``
    (1-based) for sample ``i``. That generator's limits are replaced by
    ``[-config.Pg_Upbound_Gen_Err, +config.Pg_Upbound_Gen_Err]`` for P and
    ``[-config.Qg_Upbound_Gen_Err, +config.Qg_Upbound_Gen_Err]`` for Q.
    Positive ids (branch contingencies) and zero leave the base limits untouched.

    Args:
        Pred_Pg: predicted active power; only ``shape[0]`` (sample count) is used.
        Pred_Qg: predicted reactive power; only ``shape[0]`` (sample count) is used.

    Returns:
        Tuple ``(PG_up, PG_low, QG_up, QG_low)`` of arrays with one row per sample.

    Raises:
        IndexError: if a negative id refers to a row beyond the number of
            samples, or to a generator column that does not exist.
    """
    n_pg = Pred_Pg.shape[0]
    n_qg = Pred_Qg.shape[0]

    # Tile the base limits to one row per sample, e.g. (n_samples, n_gen).
    new_PG_Upbound = np.tile(config.PG_Upbound, (n_pg, 1))
    new_PG_Lowbound = np.tile(config.PG_Lowbound, (n_pg, 1))
    new_QG_Upbound = np.tile(config.QG_Upbound, (n_qg, 1))
    new_QG_Lowbound = np.tile(config.QG_Lowbound, (n_qg, 1))

    # Only strictly negative ids denote a generator scenario. Treating 0 as one
    # would map it to column -1 and silently override the last generator.
    test_ids = np.ravel(config.branch_test_id).astype(int)
    rows = np.flatnonzero(test_ids < 0)
    gen_idx = -test_ids[rows] - 1  # 1-based generator id -> 0-based column

    new_PG_Upbound[rows, gen_idx] = config.Pg_Upbound_Gen_Err
    new_PG_Lowbound[rows, gen_idx] = -config.Pg_Upbound_Gen_Err
    new_QG_Upbound[rows, gen_idx] = config.Qg_Upbound_Gen_Err
    new_QG_Lowbound[rows, gen_idx] = -config.Qg_Upbound_Gen_Err

    return new_PG_Upbound, new_PG_Lowbound, new_QG_Upbound, new_QG_Lowbound

# Pg Qg violation
def get_PQ_violation(Pred_Pg, Pred_Qg):
    """Measure generator active/reactive power limit violations.

    Limits come from ``gen_bound``: the base limits, with scenario-specific
    overrides for negative ``config.branch_test_id`` entries. For every entry,
    the violation is the amount by which the prediction lies above the upper
    limit plus the amount by which it lies below the lower limit. Amounts
    smaller than ``config.DELTA`` are treated as zero.

    Args:
        Pred_Pg: predicted active power, 2-D ``[sample, generator]``.
        Pred_Qg: predicted reactive power, 2-D ``[sample, generator]``.

    Returns:
        Tuple of, in order:
        - PG_violation_ratio: fraction (0..1, not a percentage) of non-zero
          entries in ``PG_violation``;
        - QG_violation_ratio: the same for ``QG_violation``;
        - PG_violation_gen: per-sample sum of ``PG_violation`` (``axis=1``);
        - QG_violation_gen: per-sample sum of ``QG_violation``;
        - PG_violation: element-wise Pg violation amounts;
        - QG_violation: element-wise Qg violation amounts;
        - PQ_violation_num: number of samples whose combined P+Q violation
          sum is non-zero;
        - PQG_violation_index: ``np.where`` result for samples whose combined
          P+Q violation sum is positive.
    """
    gen_PG_Upbound, gen_PG_Lowbound, gen_QG_Upbound, gen_QG_Lowbound = gen_bound(Pred_Pg, Pred_Qg)

    def _excess(pred, upper, lower):
        # Distance outside [lower, upper] per entry; sub-DELTA amounts are noise.
        above = pred - upper
        above[above < config.DELTA] = 0
        below = lower - pred
        below[below < config.DELTA] = 0
        return above + below

    PG_violation = _excess(Pred_Pg, gen_PG_Upbound, gen_PG_Lowbound)
    QG_violation = _excess(Pred_Qg, gen_QG_Upbound, gen_QG_Lowbound)

    PG_violation_ratio = np.count_nonzero(PG_violation) / PG_violation.size
    QG_violation_ratio = np.count_nonzero(QG_violation) / QG_violation.size

    PG_violation_gen = np.sum(PG_violation, axis=1)
    QG_violation_gen = np.sum(QG_violation, axis=1)

    # Samples with any P or Q violation.
    PQ_violation_gen = PG_violation_gen + QG_violation_gen
    PQ_violation_num = np.count_nonzero(PQ_violation_gen)
    PQG_violation_index = np.where(PQ_violation_gen > 0)

    return PG_violation_ratio, QG_violation_ratio, PG_violation_gen, QG_violation_gen, PG_violation, QG_violation, PQ_violation_num, PQG_violation_index


def get_branch_violation(V):
    """Return the mean percentage of checked branches whose flow exceeds its rating.

    Contingency selection: for each sample ``i``, ``config.branch_test_id[i]``
    selects the contingency. A positive value takes that (1-based) branch out
    of service. Any other value leaves every branch at its base-case status;
    negative values are generator scenarios, handled by ``gen_bound``.

    Flow computation: the branch admittance matrices are rebuilt for that
    topology with ``makeYbus``. The complex power at both ends of every branch
    is then computed from the sample's bus voltages and scaled by
    ``config.baseMVA``.

    Violation test: a branch end violates when ``|S|`` exceeds the branch's
    column-5 rating (RATE_A in the PYPOWER branch layout) plus a small
    tolerance. Only branches that are in service and have a non-zero rating
    are checked.

    Args:
        V: complex bus voltages, one row per sample. ``V[i]`` is indexed
            directly with the from/to bus columns of ``config.branch``, so
            those must be 0-based bus positions.

    Returns:
        Average over samples of the per-sample percentage of checked branches
        that violate. A sample whose total excess is <= 1e-4 contributes 0.
        Returns 0.0 for an empty batch.
    """
    ctol = 5e-3   # tolerance added to every rating before comparison
    rate_col = 5  # rating column (RATE_A) of the branch matrix

    test_num = np.shape(V)[0]
    if test_num == 0:
        return 0.0

    branch_temp = config.branch.copy()
    # Only BR_STATUS changes between samples; bus ids and ratings are fixed.
    from_bus = branch_temp[:, 0].astype(int)
    to_bus = branch_temp[:, 1].astype(int)
    rated = branch_temp[:, rate_col] != 0

    ratio_sum = 0.0
    for i in range(test_num):
        # Restore base-case statuses, then trip this sample's branch (if any).
        branch_temp[:, BR_STATUS] = config.branch[:, BR_STATUS]
        if config.branch_test_id[i] > 0:
            branch_temp[config.branch_test_id[i] - 1, BR_STATUS] = 0

        _, Yf, Yt = makeYbus(config.baseMVA, config.bus, branch_temp)

        volt = V[i]
        # Complex power at each branch end: S = V * conj(I), scaled by baseMVA.
        Ff = np.multiply(Yf.dot(volt).conj(), volt[from_bus]) * config.baseMVA
        Ft = np.multiply(Yt.dot(volt).conj(), volt[to_bus]) * config.baseMVA

        checked = np.flatnonzero(rated & (branch_temp[:, BR_STATUS] != 0))
        limit = branch_temp[checked, rate_col] + ctol
        excess = (np.maximum(np.abs(Ff[checked]) - limit, 0)
                  + np.maximum(np.abs(Ft[checked]) - limit, 0))

        if np.sum(excess) > 1e-4:
            # Sum the per-sample ratios so the result below is their mean.
            ratio_sum += np.count_nonzero(excess) / np.size(limit)

    return ratio_sum / test_num * 100


def get_V_violation(Pred_VM, Pred_Va):
    """Score predicted voltage magnitudes and branch angle differences against limits.

    Magnitudes are compared against ``[config.VmLb, config.VmUb]``. Angles are
    converted from radians to degrees. Each branch's from-bus minus to-bus
    angle is then compared against ``[config.VA_Lowbound, config.VA_Upbound]``.
    Excesses smaller than ``config.DELTA`` are treated as zero.

    Args:
        Pred_VM: predicted voltage magnitudes, 2-D ``[sample, bus]``.
        Pred_Va: predicted voltage angles in radians, 2-D ``[sample, bus]``.

    Returns:
        Tuple of, in order:
        - VM_violation_ratio: percentage of (sample, bus) magnitude entries out
          of limits;
        - VA_violation_ratio: mean over samples of the percentage of branches
          whose angle difference is out of limits;
        - VM_violation_bus: per-sample total magnitude excess (``axis=1`` sum);
        - VA_violation_bus: total angle excess of the LAST sample only.
    """
    def _drop_small(excess):
        # In place: excesses below DELTA count as "within limits".
        excess[excess < config.DELTA] = 0
        return excess

    # Voltage-magnitude limits.
    magnitudes = Pred_VM.copy()
    vm_above = _drop_small(magnitudes - config.VmUb)
    vm_below = _drop_small(config.VmLb - magnitudes)
    VM_violation_bus = np.sum(vm_above, axis=1) + np.sum(vm_below, axis=1)
    n_entries = magnitudes.shape[0] * magnitudes.shape[1]
    VM_violation_ratio = (np.count_nonzero(vm_below + vm_above) / n_entries) * 100

    # Branch angle-difference limits, in degrees.
    degrees = Pred_Va * 180 / math.pi
    f_idx = config.branch[:, 0].astype(int)
    t_idx = config.branch[:, 1].astype(int)

    pct_total = 0
    for sample in degrees:
        angle_diff = sample[f_idx] - sample[t_idx]
        va_above = _drop_small(angle_diff - config.VA_Upbound)
        va_below = _drop_small(config.VA_Lowbound - angle_diff)
        # NOTE(review): reassigned every iteration, so only the last sample's
        # total is returned (unlike VM_violation_bus, which is per sample).
        VA_violation_bus = np.sum(va_above) + np.sum(va_below)
        flagged = np.count_nonzero(va_below + va_above)
        pct_total = pct_total + ((flagged / (angle_diff.shape[0])) * 100)
    VA_violation_ratio = pct_total / degrees.shape[0]

    return VM_violation_ratio, VA_violation_ratio, VM_violation_bus, VA_violation_bus